import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import os
import random
import logging
logger = logging.getLogger(__name__)

from .utils import CIFAR10_TRAIN_TRAINSFORM, CIFAR10_EVAL_TRAINSFORM, unpickle, get_unsupervised_transform, CIFAR10_NORMALIZE, other_class



class CIFAR10(torch.utils.data.Dataset):
    """CIFAR-10 dataset read from the pickled python batches on disk.

    The batches are expected under ``<data_path>/cifar-10-batches-py/``:
    ``data_batch_1`` .. ``data_batch_5`` for the training split and
    ``test_batch`` for the evaluation split. Each batch is loaded with
    ``unpickle`` and must expose the ``b'data'`` and ``b'labels'`` entries.

    Attributes:
        data_list: array of flattened images, one row per sample.
        label_list: 1-D array of integer labels aligned with ``data_list``.
        transform: callable applied to the PIL image in ``__getitem__``.
    """

    def __init__(self, train, data_path, unsupervised_transform=False):
        """Load one split into memory and select the image transform.

        Args:
            train: if true, load the five training batches; otherwise load
                the test batch.
            data_path: directory that contains ``cifar-10-batches-py``.
            unsupervised_transform: if true, use
                ``get_unsupervised_transform(normalize=CIFAR10_NORMALIZE)``
                regardless of the split; otherwise use the train or eval
                transform matching ``train``.
        """
        super(CIFAR10, self).__init__()
        if train:
            batch_files = ['cifar-10-batches-py/data_batch_%d' % i for i in range(1, 6)]
        else:
            batch_files = ['cifar-10-batches-py/test_batch']

        data_list, label_list = [], []
        for batch_file in batch_files:
            d = unpickle(os.path.join(data_path, batch_file))
            data_list.append(d[b'data'])
            label_list.append(d[b'labels'])
        # Both splits go through the same concatenation so that data_list and
        # label_list are always numpy arrays; the test labels would otherwise
        # stay a plain Python list while the train labels are an ndarray.
        self.data_list = np.concatenate(data_list, axis=0)
        self.label_list = np.concatenate(label_list, axis=0)

        if unsupervised_transform:
            self.transform = get_unsupervised_transform(normalize=CIFAR10_NORMALIZE)
        else:
            self.transform = CIFAR10_TRAIN_TRAINSFORM if train else CIFAR10_EVAL_TRAINSFORM

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for sample ``idx``."""
        data = self.data_list[idx]
        label = self.label_list[idx]
        # Rows are stored flat and channel-first; PIL needs (H, W, C).
        image = Image.fromarray(data.reshape(3, 32, 32).transpose(1, 2, 0))
        image = self.transform(image)
        return image, label

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.label_list)


class CIFAR10_TwoCrops(CIFAR10):
    """CIFAR-10 variant that yields two transformed views of every image.

    ``self.transform`` is always the unsupervised transform. ``self.transform_``
    is the eval transform when ``need_transform_`` is true, and a second
    unsupervised transform otherwise.
    """

    def __init__(self, train, data_path, need_transform_=False):
        super(CIFAR10_TwoCrops, self).__init__(train, data_path)
        # Replace the transform chosen by the parent constructor.
        self.transform = get_unsupervised_transform(normalize=CIFAR10_NORMALIZE)
        if need_transform_:
            self.transform_ = CIFAR10_EVAL_TRAINSFORM
        else:
            self.transform_ = get_unsupervised_transform(normalize=CIFAR10_NORMALIZE)

    def __getitem__(self, idx):
        """Return ``((view_from_transform_, view_from_transform), label)``."""
        flat_pixels = self.data_list[idx]
        target = self.label_list[idx]
        # Flat channel-first row -> (H, W, C) array that PIL can consume.
        hwc_pixels = flat_pixels.reshape(3, 32, 32).transpose(1, 2, 0)
        picture = Image.fromarray(hwc_pixels)
        # transform_ is applied before transform, as in the single-statement form.
        views = (self.transform_(picture), self.transform(picture))
        return views, target


class NoisyCIFAR10(CIFAR10):
    """CIFAR-10 whose labels are corrupted once, at construction time.

    Two noise models are supported:

    * asymmetric (``is_asym=True``): for a fixed list of (source, target)
      class pairs, a ``noise_rate`` fraction of each source class is
      relabelled as its target class;
    * symmetric (``is_asym=False`` and ``noise_rate > 0``): an equal number
      of samples per class is relabelled with the value returned by
      ``other_class``.

    ``self.label_list`` is modified in place; the images are untouched.
    """

    def __init__(self, train, data_path, unsupervised_transform=False, noise_rate=0.0, is_asym=False, seed=0):
        """Load the split and inject label noise.

        Args:
            train: forwarded to ``CIFAR10``.
            data_path: forwarded to ``CIFAR10``.
            unsupervised_transform: forwarded to ``CIFAR10``.
            noise_rate: fraction of labels to corrupt (per source class for
                asymmetric noise, of the whole split for symmetric noise).
            is_asym: select the asymmetric noise model.
            seed: value passed to ``np.random.seed`` before sampling. Note
                that this reseeds NumPy's global random state.
        """
        super(NoisyCIFAR10, self).__init__(train, data_path, unsupervised_transform=unsupervised_transform)
        np.random.seed(seed)
        if is_asym:
            # automobile <- truck, bird -> airplane, cat <-> dog, deer -> horse
            source_class = [9, 2, 3, 5, 4]
            target_class = [1, 0, 5, 3, 7]
            # Select samples from a snapshot of the clean labels. Looking at
            # the labels being mutated would make the cat <-> dog pair
            # interfere: cats already flipped to dog would be counted as dogs
            # and could be flipped back, distorting the noise rate of both.
            clean_labels = np.array(self.label_list)
            for s, t in zip(source_class, target_class):
                cls_idx = np.where(clean_labels == s)[0]
                n_noisy = int(noise_rate * cls_idx.shape[0])
                noisy_sample_index = np.random.choice(cls_idx, n_noisy, replace=False)
                for idx in noisy_sample_index:
                    self.label_list[idx] = t
            # The asymmetric path does not emit the per-class statistics below.
            return
        elif noise_rate > 0:
            n_samples = len(self.label_list)
            n_noisy = int(noise_rate * n_samples)
            logger.info("%d Noisy samples", n_noisy)
            labels = np.array(self.label_list)
            class_index = [np.where(labels == i)[0] for i in range(10)]
            # Same quota for every class, rounded down.
            class_noisy = int(n_noisy / 10)
            noisy_idx = []
            for d in range(10):
                noisy_class_index = np.random.choice(class_index[d], class_noisy, replace=False)
                noisy_idx.extend(noisy_class_index)
                logger.info("Class %d, number of noisy %d", d, len(noisy_class_index))
            for i in noisy_idx:
                # other_class presumably returns a label different from
                # current_class -- confirm against .utils.
                self.label_list[i] = other_class(n_classes=10, current_class=self.label_list[i])
        logger.info("Pring noisy label generation statistics:")
        final_labels = np.array(self.label_list)
        for i in range(10):
            n_in_class = np.sum(final_labels == i)
            logger.info("Noisy class %s, has %s samples.", i, n_in_class)

class NoisyCIFAR10_TwoCrops(NoisyCIFAR10):
    """Noisy-label CIFAR-10 that yields two transformed views of every image.

    Label noise is generated by ``NoisyCIFAR10``. ``self.transform`` is always
    the unsupervised transform; ``self.transform_`` is the eval transform when
    ``need_transform_`` is true and a second unsupervised transform otherwise.
    """

    def __init__(self, train, data_path, unsupervised_transform=False, need_transform_=False, noise_rate=0.0, is_asym=False, seed=0):
        """Build the noisy dataset and install the two view transforms.

        Args:
            train: forwarded to ``NoisyCIFAR10``.
            data_path: forwarded to ``NoisyCIFAR10``.
            unsupervised_transform: forwarded to ``NoisyCIFAR10``; the
                transform it selects is replaced right after by this
                constructor.
            need_transform_: if true, the first view uses the eval transform.
            noise_rate: forwarded to ``NoisyCIFAR10``.
            is_asym: forwarded to ``NoisyCIFAR10``.
            seed: forwarded to ``NoisyCIFAR10`` so the noise pattern can be
                varied; defaults to the parent's default of 0.
        """
        super(NoisyCIFAR10_TwoCrops, self).__init__(
            train,
            data_path,
            unsupervised_transform=unsupervised_transform,
            noise_rate=noise_rate,
            is_asym=is_asym,
            seed=seed,
        )
        self.transform = get_unsupervised_transform(normalize=CIFAR10_NORMALIZE)
        self.transform_ = CIFAR10_EVAL_TRAINSFORM if need_transform_ else get_unsupervised_transform(normalize=CIFAR10_NORMALIZE)

    def __getitem__(self, idx):
        """Return ``((view_from_transform_, view_from_transform), label)``."""
        data = self.data_list[idx]
        label = self.label_list[idx]
        # Rows are stored flat and channel-first; PIL needs (H, W, C).
        image = Image.fromarray(data.reshape(3, 32, 32).transpose(1, 2, 0))
        image1 = self.transform_(image)
        image2 = self.transform(image)
        return (image1, image2), label